{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms, utils\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "plt.ion()\n",
    "\n",
    "from model.generator import *\n",
    "from model.discriminator import NLayerDiscriminator as Discriminator\n",
    "from utils.dataloader import YouTubePose\n",
    "from utils.loss_functions import lossIdentity, lossShape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Filesystem locations\n",
    "dataset_dir = './Dataset/'\n",
    "checkpoint_path = './model_checkpoints/'\n",
    "\n",
    "# Training schedule\n",
    "epochs = 10\n",
    "batch_size = 4\n",
    "\n",
    "# Run on the GPU when one is visible to torch, otherwise on the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "# Optimizer step sizes\n",
    "learning_rate_generator = 0.0003\n",
    "learning_rate_discriminator = 0.1\n",
    "\n",
    "# Loss weights handed to train_model\n",
    "alpha, beta = 8, 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Unpickle the training triplet list (later passed to the YouTubePose dataset).\n",
    "triplets_path = './dataset_lists/train_datapoint_triplets.pkl'\n",
    "with open(triplets_path, mode='rb') as triplets_file:\n",
    "    datapoint_pairs = pickle.load(triplets_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Unpickle the shape-loss pair list (later passed to the YouTubePose dataset).\n",
    "shape_pairs_path = './dataset_lists/train_shapeLoss_pairs.pkl'\n",
    "with open(shape_pairs_path, mode='rb') as shape_pairs_file:\n",
    "    shapeLoss_datapoint_pairs = pickle.load(shape_pairs_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Preprocessing applied to every image: resize to 256x256, convert to a tensor,\n",
    "# then normalize the three channels with 0.5 as both arguments of Normalize.\n",
    "normalize_args = (0.5,) * 3\n",
    "preprocessing_steps = [\n",
    "    transforms.Resize((256, 256)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(normalize_args, normalize_args),\n",
    "]\n",
    "transform = transforms.Compose(preprocessing_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Dataset built from the two unpickled lists, the image root and the preprocessing pipeline.\n",
    "train_dataset = YouTubePose(\n",
    "    datapoint_pairs,\n",
    "    shapeLoss_datapoint_pairs,\n",
    "    dataset_dir,\n",
    "    transform,\n",
    ")\n",
    "\n",
    "# Shuffled loader; num_workers=0 keeps loading in the main process.\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "# Single-element list holding the number of training samples.\n",
    "dataset_sizes = [len(train_dataset)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generator and ResidualBlock are presumably provided by the star import from model.generator -- confirm.\n",
    "generator = Generator(ResidualBlock)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Discriminator is the imported NLayerDiscriminator; the argument 3 is presumably\n",
    "# the number of input channels -- TODO confirm against model/discriminator.py.\n",
    "discriminator = Discriminator(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Move both networks onto the device selected in the configuration cell.\n",
    "generator, discriminator = generator.to(device), discriminator.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Adam drives the generator; SGD with momentum drives the discriminator.\n",
    "optimizer_gen = optim.Adam(\n",
    "    generator.parameters(),\n",
    "    lr=learning_rate_generator,\n",
    ")\n",
    "optimizer_disc = optim.SGD(\n",
    "    discriminator.parameters(),\n",
    "    lr=learning_rate_discriminator,\n",
    "    momentum=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def save_checkpoint(state, dirpath, epoch):\n",
    "    \"\"\"Serialize `state` to '<dirpath>/checkpoint-<epoch>.ckpt' with torch.save.\n",
    "\n",
    "    Args:\n",
    "        state: object handed to torch.save.\n",
    "        dirpath: directory that receives the checkpoint; created when missing.\n",
    "        epoch: value embedded in the checkpoint file name.\n",
    "    \"\"\"\n",
    "    # Create the directory first so a missing folder cannot discard a finished epoch.\n",
    "    # An empty dirpath means the current directory, which os.makedirs would reject.\n",
    "    if dirpath:\n",
    "        os.makedirs(dirpath, exist_ok=True)\n",
    "    filename = 'checkpoint-{}.ckpt'.format(epoch)\n",
    "    checkpoint_file = os.path.join(dirpath, filename)\n",
    "    torch.save(state, checkpoint_file)\n",
    "    print('--- checkpoint saved to ' + str(checkpoint_file) + ' ---')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_model(gen, disc, loss_i, loss_s, optimizer_gen, optimizer_disc, alpha = 1, beta = 1, num_epochs = 10):\n",
    "    \"\"\"Train the generator and discriminator, saving a checkpoint after every epoch.\n",
    "\n",
    "    Every batch runs four optimizer steps in this order:\n",
    "      1. generator step on -loss_i(real_pooled_op, fake_pooled_op)\n",
    "      2. discriminator step on loss_i(real_op, fake_op)\n",
    "      3. generator step on loss_s(y, gen(y, x_generated)) + loss_s(x_generated, gen(x_generated, y))\n",
    "      4. generator step on loss_s(iden_2, gen(iden_1, iden_2))\n",
    "\n",
    "    Reads the notebook globals train_dataloader, device and checkpoint_path,\n",
    "    and calls the notebook-level save_checkpoint helper.\n",
    "\n",
    "    Args:\n",
    "        gen: generator module, called as gen(a, b).\n",
    "        disc: discriminator module, called as disc(a, b); its result is unpacked\n",
    "            into two values, used here as (output, pooled output).\n",
    "        loss_i: callable used as loss_i(real, fake) for the identity loss.\n",
    "        loss_s: callable used as loss_s(reference, generated) for the shape loss.\n",
    "        optimizer_gen: optimizer stepped after each generator loss.\n",
    "        optimizer_disc: optimizer stepped after the discriminator loss.\n",
    "        alpha: weight of the s2a + s2b terms in the reported epoch loss.\n",
    "        beta: weight of all shape terms in the reported epoch loss.\n",
    "        num_epochs: number of passes over train_dataloader.\n",
    "\n",
    "    Note:\n",
    "        alpha and beta only affect the printed 'Epoch Loss'; the losses that are\n",
    "        backpropagated are not scaled by them.\n",
    "\n",
    "    Returns:\n",
    "        The (gen, disc) pair that was passed in, after training.\n",
    "    \"\"\"\n",
    "    dataloader = train_dataloader\n",
    "    # Per-sample averages are taken over the dataset behind the loader.\n",
    "    num_samples = len(dataloader.dataset)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-'*10)\n",
    "        gen.train()\n",
    "        disc.train()\n",
    "        since = time.time()\n",
    "        running_loss_iden = 0.0\n",
    "        running_loss_s1 = 0.0\n",
    "        running_loss_s2a = 0.0\n",
    "        running_loss_s2b = 0.0\n",
    "\n",
    "        for sample_batched in dataloader:\n",
    "            x_gen = sample_batched['x_gen'].to(device)\n",
    "            y = sample_batched['y'].to(device)\n",
    "            x_dis = sample_batched['x_dis'].to(device)\n",
    "            iden_1 = sample_batched['iden_1'].to(device)\n",
    "            iden_2 = sample_batched['iden_2'].to(device)\n",
    "            batch_len = x_gen.size(0)\n",
    "\n",
    "            optimizer_gen.zero_grad()\n",
    "            optimizer_disc.zero_grad()\n",
    "\n",
    "            with torch.set_grad_enabled(True):\n",
    "                # Step 1: generator update on the negated identity loss of the pooled outputs.\n",
    "                x_generated = gen(x_gen, y)\n",
    "                fake_op, fake_pooled_op = disc(x_gen, x_generated)\n",
    "                real_op, real_pooled_op = disc(x_gen, x_dis)\n",
    "                loss_identity_gen = -loss_i(real_pooled_op, fake_pooled_op)\n",
    "                loss_identity_gen.backward(retain_graph=True)\n",
    "                optimizer_gen.step()\n",
    "                # NOTE(review): the retained graph is reused by the backward passes below after\n",
    "                # the parameters were stepped in place; recent PyTorch releases may reject that\n",
    "                # with an in-place modification error -- confirm on the target version.\n",
    "\n",
    "                # Step 2: discriminator update on the identity loss of the un-pooled outputs.\n",
    "                optimizer_disc.zero_grad()\n",
    "                loss_identity_disc = loss_i(real_op, fake_op)\n",
    "                loss_identity_disc.backward(retain_graph=True)\n",
    "                optimizer_disc.step()\n",
    "\n",
    "                # Step 3: generator update on the two-term shape loss.\n",
    "                optimizer_gen.zero_grad()\n",
    "                optimizer_disc.zero_grad()\n",
    "                x_ls2a = gen(y, x_generated)\n",
    "                x_ls2b = gen(x_generated, y)\n",
    "\n",
    "                loss_s2a = loss_s(y, x_ls2a)\n",
    "                loss_s2b = loss_s(x_generated, x_ls2b)\n",
    "                loss_s2 = loss_s2a + loss_s2b\n",
    "\n",
    "                loss_s2.backward()\n",
    "                optimizer_gen.step()\n",
    "\n",
    "                # Step 4: generator update on the shape loss of the (iden_1, iden_2) pair.\n",
    "                optimizer_gen.zero_grad()\n",
    "                optimizer_disc.zero_grad()\n",
    "\n",
    "                # Use the `gen` argument, not the notebook-level `generator` global.\n",
    "                x_ls1 = gen(iden_1, iden_2)\n",
    "\n",
    "                loss_s1 = loss_s(iden_2, x_ls1)\n",
    "                loss_s1.backward()\n",
    "                optimizer_gen.step()\n",
    "\n",
    "            # Accumulate batch-size-weighted sums so the epoch figures are per-sample means.\n",
    "            running_loss_iden += loss_identity_disc.item() * batch_len\n",
    "            running_loss_s1 += loss_s1.item() * batch_len\n",
    "            running_loss_s2a += loss_s2a.item() * batch_len\n",
    "            running_loss_s2b += loss_s2b.item() * batch_len\n",
    "\n",
    "        # The weighted total depends only on the accumulated sums, so compute it once per epoch.\n",
    "        running_loss = running_loss_iden + beta * (running_loss_s1 + alpha * (running_loss_s2a + running_loss_s2b))\n",
    "        epoch_loss_iden = running_loss_iden / num_samples\n",
    "        epoch_loss_s1 = running_loss_s1 / num_samples\n",
    "        epoch_loss_s2a = running_loss_s2a / num_samples\n",
    "        epoch_loss_s2b = running_loss_s2b / num_samples\n",
    "        epoch_loss = running_loss / num_samples\n",
    "        print('Identity Loss: {:.4f} Loss Shape1: {:.4f} Loss Shape2a: {:.4f} Loss Shape2b: {:.4f}'.format(\n",
    "            epoch_loss_iden, epoch_loss_s1, epoch_loss_s2a, epoch_loss_s2b))\n",
    "        print('Epoch Loss: {:.4f}'.format(epoch_loss))\n",
    "\n",
    "        save_checkpoint({\n",
    "            'epoch': epoch + 1,\n",
    "            'gen_state_dict': gen.state_dict(),\n",
    "            'disc_state_dict': disc.state_dict(),\n",
    "            'gen_opt': optimizer_gen.state_dict(),\n",
    "            'disc_opt': optimizer_disc.state_dict()\n",
    "        }, checkpoint_path, epoch + 1)\n",
    "        # Measure once so minutes and seconds come from the same instant.\n",
    "        elapsed = time.time() - since\n",
    "        print('Time taken by epoch: {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))\n",
    "        print()\n",
    "\n",
    "    return gen, disc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/9\n",
      "----------\n"
     ]
    }
   ],
   "source": [
    "# Run the full training schedule; the trained networks replace the notebook-level ones.\n",
    "generator, discriminator = train_model(\n",
    "    generator,\n",
    "    discriminator,\n",
    "    lossIdentity,\n",
    "    lossShape,\n",
    "    optimizer_gen,\n",
    "    optimizer_disc,\n",
    "    alpha=alpha,\n",
    "    beta=beta,\n",
    "    num_epochs=epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
